"""
Phase 5: Current Threshold Cohort Positioning
Development Threshold Paper — Policy Brief Component

Applies historical crossing models to countries currently in the $9k-$25k GDP per capita (PPP) zone.
Generates crossing probabilities, demographic windows, risk classifications, and historical analogues.
"""

import pandas as pd
import numpy as np
import statsmodels.api as sm
from scipy import stats
import sys
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
panel = full[full['year'] <= 2024].copy()  # Observed years only (through 2024)
full_proj = full.copy()  # Includes projections

# Resource rents are optional. If the WDI extract is missing, empty,
# unparseable, or lacks the expected columns, fall back to an all-NaN column
# so later code that reads 'resource_rents_gdp' still works. Only these
# expected failures are caught; anything else (a typo, KeyboardInterrupt, ...)
# propagates instead of being silently swallowed.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except (OSError, KeyError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
    print(f"Warning: resource rents unavailable ({err!r}); resource_rents_gdp set to NaN.")
    panel['resource_rents_gdp'] = np.nan

# Historical zone-crossing records; columns used later include cross_9k,
# status, exited_above and the *_entry / *_exit values.
cross_df = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv')

# GDP per capita (PPP) bounds of the threshold zone.
LOWER = 9000
UPPER = 25000

# ═══════════════════════════════════════════════════════════════════════
# 5a. Current snapshot: countries in the zone
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("5a. CURRENT THRESHOLD COHORT")
print("=" * 70)

# Most recent observation per country (2019 onward).
# GroupBy.last() follows row order within each group, so sort by year first;
# otherwise an unsorted panel would not yield the latest values.
# NOTE: GroupBy.last() returns the last NON-NULL value of each column
# separately, so one country's fields may come from different years
# (e.g. GDP from 2024 but kaopen from an earlier year if its coverage ends sooner).
recent = (
    panel[panel['year'] >= 2019]
    .sort_values('year', kind='stable')
    .groupby('iso3')
    .last()
    .reset_index()
)
zone_now = recent[(recent['gdp_pc_ppp'] >= LOWER) & (recent['gdp_pc_ppp'] <= UPPER)].copy()

print(f"\nCountries currently in $9k-$25k zone: {len(zone_now)}")

# Key variables
snapshot_vars = ['gdp_pc_ppp', 'Z_1', 'old_dep', 'youth_dep', 'working_age_share',
                 'kaopen', 'ca_gdp', 'gross_savings_gdp', 'fiscal_bal_gdp',
                 'nfa_gdp', 'resource_rents_gdp', 'rgdp_growth', 'trade_openness',
                 'life_expectancy', 'human_capital']

# ═══════════════════════════════════════════════════════════════════════
# 5b. Demographic projections for cohort
# ═══════════════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print("5b. DEMOGRAPHIC PROJECTIONS FOR THRESHOLD COHORT")
print('=' * 70)

# Attach projected demographics as <var>_<year> columns (e.g. Z_1_2030,
# old_dep_2050). Left merge: countries with no projection row get NaN.
projection_years = [2030, 2040, 2050]
_proj_vars = ['Z_1', 'old_dep', 'youth_dep', 'working_age_share']
for target_year in projection_years:
    year_rows = full_proj.loc[full_proj['year'] == target_year, ['iso3', *_proj_vars]]
    suffixed = year_rows.rename(columns={v: f'{v}_{target_year}' for v in _proj_vars})
    zone_now = zone_now.merge(suffixed, on='iso3', how='left')


def _cell(value, spec, scale=1, suffix=''):
    """Return ``format(value * scale, spec) + suffix``, or '—' when value is missing/NaN."""
    if pd.notna(value):
        return format(value * scale, spec) + suffix
    return '—'


# Print snapshot, richest first
print(f"\n{'Country':6s} {'GDP/cap':>8s} {'Z₁':>6s} {'Z₁_30':>6s} {'Z₁_50':>6s} "
      f"{'OADR':>6s} {'OADR_50':>7s} {'KAO':>5s} {'Sav':>5s} {'Rents':>6s}")
print("-" * 70)
by_gdp = zone_now.sort_values('gdp_pc_ppp', ascending=False)
for _, country in by_gdp.iterrows():
    z1_30 = _cell(country.get('Z_1_2030'), '.2f')
    z1_50 = _cell(country.get('Z_1_2050'), '.2f')
    # Old-age dependency ratio shown as a percentage.
    oadr = _cell(country.get('old_dep'), '.0f', scale=100, suffix='%')
    oadr50 = _cell(country.get('old_dep_2050'), '.0f', scale=100, suffix='%')
    kao = _cell(country.get('kaopen'), '.1f')
    sav = _cell(country.get('gross_savings_gdp'), '.0f')
    rents_str = _cell(country.get('resource_rents_gdp'), '.1f')
    print(f"{country['iso3']:6s} {country['gdp_pc_ppp']:8,.0f} {country['Z_1']:6.2f} {z1_30:>6s} {z1_50:>6s} "
          f"{oadr:>6s} {oadr50:>7s} {kao:>5s} {sav:>5s} {rents_str:>6s}")

# ═══════════════════════════════════════════════════════════════════════
# 5c. Predicted probability of crossing
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("5c. PREDICTED PROBABILITY OF CROSSING")
print("=" * 70)

# Initialise so downstream code always finds the column, even when the model
# cannot be fitted or no current country has the predictors.
zone_now['p_cross'] = np.nan

# Fit logit on historical zone entrants: rows with a recorded cross_9k,
# excluding those whose status is 'Always above'.
zone_entrants = cross_df[cross_df['cross_9k'].notna() & (cross_df['status'] != 'Always above')].copy()

logit_vars = ['Z_1_entry', 'gdp_pc_ppp_entry', 'kaopen_entry', 'gross_savings_gdp_entry']
logit_sample = zone_entrants[['exited_above'] + logit_vars].dropna()

if len(logit_sample) >= 20:
    y = logit_sample['exited_above']
    # has_constant='add' always adds the intercept; the default ('skip') drops it
    # whenever some column has zero range, which would misalign predict().
    X = sm.add_constant(logit_sample[logit_vars], has_constant='add')
    logit_model = sm.Logit(y, X).fit(disp=0)
    print(f"Logit model: n={len(logit_sample)}, pseudo-R²={logit_model.prsquared:.3f}")
    for var in logit_vars:
        sig = '***' if logit_model.pvalues[var] < 0.01 else '**' if logit_model.pvalues[var] < 0.05 else '*' if logit_model.pvalues[var] < 0.1 else ''
        print(f"  {var}: coef={logit_model.params[var]:.4f}, p={logit_model.pvalues[var]:.4f}{sig}")

    # Predict for current cohort.
    # Current values stand in for the zone-entry values the model was fit on.
    # NOTE(review): historical entry GDP is presumably near $9k while the cohort
    # spans up to $25k, so GDP enters out of sample — treat P(cross) as indicative.
    zone_now['Z_1_entry'] = zone_now['Z_1']
    zone_now['gdp_pc_ppp_entry'] = zone_now['gdp_pc_ppp']
    zone_now['kaopen_entry'] = zone_now['kaopen']
    zone_now['gross_savings_gdp_entry'] = zone_now['gross_savings_gdp']

    pred_sample = zone_now[logit_vars].dropna()
    if len(pred_sample) > 0:
        # 'add' matters here: with a single row every column has zero range,
        # and the default would skip the constant column.
        X_pred = sm.add_constant(pred_sample, has_constant='add')
        zone_now.loc[pred_sample.index, 'p_cross'] = logit_model.predict(X_pred)

    # Also fit simpler model (Z₁ + GDP only) for countries with missing data
    simple_vars = ['Z_1_entry', 'gdp_pc_ppp_entry']
    simple_sample = zone_entrants[['exited_above'] + simple_vars].dropna()
    simple_model = sm.Logit(simple_sample['exited_above'],
                            sm.add_constant(simple_sample[simple_vars], has_constant='add')).fit(disp=0)

    missing_pred = zone_now[zone_now['p_cross'].isna()]
    simple_pred_sample = missing_pred[simple_vars].dropna()
    if len(simple_pred_sample) > 0:
        # Often only one or two countries fall back here, hence has_constant='add'.
        zone_now.loc[simple_pred_sample.index, 'p_cross'] = simple_model.predict(
            sm.add_constant(simple_pred_sample, has_constant='add'))

    print(f"\nPredicted crossing probabilities for current cohort:")
    print(f"{'Country':6s} {'GDP/cap':>8s} {'Z₁':>6s} {'P(cross)':>9s}")
    print("-" * 35)
    for _, row in zone_now.sort_values('p_cross', ascending=False).iterrows():
        p = f"{row['p_cross']:.2f}" if pd.notna(row.get('p_cross')) else '—'
        print(f"{row['iso3']:6s} {row['gdp_pc_ppp']:8,.0f} {row['Z_1']:6.2f} {p:>9s}")
else:
    print(f"Logit skipped: only {len(logit_sample)} complete historical entrants (need >= 20); "
          f"P(cross) left missing.")

# ═══════════════════════════════════════════════════════════════════════
# 5d. Demographic window analysis
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("5d. DEMOGRAPHIC WINDOW: YEARS OF FAVORABLE DEMOGRAPHICS")
print("=" * 70)

# Benchmark: Z₁ at zone exit among historical crossers. A country's window is
# taken to close in the first projected year its Z₁ exceeds the benchmark
# (mean = baseline window, p75 = generous window).
crossers_at_exit = cross_df[cross_df['exited_above'] == 1]['Z_1_exit'].dropna()
z1_threshold = crossers_at_exit.mean()
z1_p75 = crossers_at_exit.quantile(0.75)

print(f"Z₁ at exit for historical crossers: mean={z1_threshold:.3f}, p75={z1_p75:.3f}")

WINDOW_BASE_YEAR = 2024  # last observed year (the panel is cut at 2024)
NO_CLOSE_YEAR = 2100     # sentinel: benchmark never exceeded within the projections

# Start every country as "unknown". A country with no usable Z₁ projections,
# or an undefined benchmark (no historical crossers), must not be reported as
# having an open-ended window.
zone_now['window_closes'] = np.nan
zone_now['years_remaining'] = np.nan
zone_now['generous_window'] = np.nan

for _, row in zone_now.iterrows():
    iso = row['iso3']
    proj = full_proj[(full_proj['iso3'] == iso) & (full_proj['year'] >= WINDOW_BASE_YEAR) &
                     (full_proj['Z_1'].notna())].sort_values('year')
    if proj.empty or pd.isna(z1_threshold):
        continue  # no evidence either way: leave NaN
    is_country = zone_now['iso3'] == iso

    # First year Z₁ exceeds the crosser mean; sentinel if never within projections.
    above_thresh = proj.loc[proj['Z_1'] > z1_threshold, 'year']
    closes = above_thresh.min() if len(above_thresh) > 0 else NO_CLOSE_YEAR
    zone_now.loc[is_country, 'window_closes'] = closes
    zone_now.loc[is_country, 'years_remaining'] = closes - WINDOW_BASE_YEAR

    # First year Z₁ exceeds crosser p75 (generous window)
    above_p75 = proj.loc[proj['Z_1'] > z1_p75, 'year']
    if len(above_p75) > 0:
        zone_now.loc[is_country, 'generous_window'] = above_p75.min() - WINDOW_BASE_YEAR

print(f"\n{'Country':6s} {'GDP/cap':>8s} {'Z₁ now':>7s} {'Z₁ 2050':>7s} {'Window closes':>14s} {'Years left':>11s}")
print("-" * 60)
for _, row in zone_now.sort_values('years_remaining').iterrows():
    z1_50 = f"{row.get('Z_1_2050', np.nan):.2f}" if pd.notna(row.get('Z_1_2050')) else '—'
    closes = f"{int(row['window_closes'])}" if pd.notna(row.get('window_closes')) else '—'
    yrs = f"{row['years_remaining']:.0f}" if pd.notna(row.get('years_remaining')) else '—'
    print(f"{row['iso3']:6s} {row['gdp_pc_ppp']:8,.0f} {row['Z_1']:7.2f} {z1_50:>7s} {closes:>14s} {yrs:>11s}")

# ═══════════════════════════════════════════════════════════════════════
# 5e. Risk classification: traffic light
# ═══════════════════════════════════════════════════════════════════════
_banner = "=" * 70
print(f"\n{_banner}")
print("5e. RISK CLASSIFICATION (TRAFFIC LIGHT)")
print(_banner)

def classify_risk(row):
    """
    Assign a traffic-light risk label to one cohort row.

    Reads ``p_cross`` (predicted crossing probability) and ``years_remaining``
    (demographic window) through ``row.get``; either may be absent or NaN.

    RED    if P(cross) < 0.2, or fewer than 5 years remain (either suffices).
    GREEN  if P(cross) > 0.5 and more than 15 years remain (both must be known).
    AMBER  in every other case, including when only one input is known and it
           is not in the red range.
    Returns 'Insufficient data' when both inputs are missing.
    """
    prob = row.get('p_cross', np.nan)
    window = row.get('years_remaining', np.nan)
    has_prob = pd.notna(prob)
    has_window = pd.notna(window)

    if not (has_prob or has_window):
        return 'Insufficient data'
    if (has_prob and prob < 0.2) or (has_window and window < 5):
        return 'RED'
    if has_prob and has_window and prob > 0.5 and window > 15:
        return 'GREEN'
    return 'AMBER'

zone_now['risk'] = zone_now.apply(classify_risk, axis=1)

# Commodity double cliff: high resource rents (>= 10, presumably % of GDP)
# combined with a demographic window shorter than 20 years. Missing rents
# count as 0 and a missing window as 100 years, so missing data never
# triggers the flag.
rents_share = zone_now['resource_rents_gdp'].fillna(0)
window_len = zone_now['years_remaining'].fillna(100)
zone_now['commodity_double_cliff'] = rents_share.ge(10) & window_len.lt(20)

print("\nRisk classification:")
for label in ('GREEN', 'AMBER', 'RED', 'Insufficient data'):
    members = zone_now.loc[zone_now['risk'] == label]
    if members.empty:
        continue
    print(f"\n  {label}: {len(members)} countries")
    for _, entry in members.sort_values('gdp_pc_ppp', ascending=False).iterrows():
        prob_txt = format(entry['p_cross'], '.2f') if pd.notna(entry.get('p_cross')) else '—'
        window_txt = f"{entry['years_remaining']:.0f}yr" if pd.notna(entry.get('years_remaining')) else '—'
        cliff_txt = ' [DOUBLE CLIFF]' if entry.get('commodity_double_cliff') else ''
        print(f"    {entry['iso3']:6s} GDP={entry['gdp_pc_ppp']:,.0f}  P(cross)={prob_txt}  Window={window_txt}{cliff_txt}")

# ═══════════════════════════════════════════════════════════════════════
# 5f. Historical analogues
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("5f. HISTORICAL ANALOGUES")
print("=" * 70)

# For each current cohort country, find the closest OTHER historical zone
# entrant in standardized (entry GDP, entry Z₁) space.
historical = cross_df[cross_df['cross_9k'].notna() & (cross_df['status'] != 'Always above')].copy()
historical = historical[historical['gdp_pc_ppp_entry'].notna() & historical['Z_1_entry'].notna()]

# Standardize for distance calculation
gdp_std = historical['gdp_pc_ppp_entry'].std()
z1_std = historical['Z_1_entry'].std()

if gdp_std > 0 and z1_std > 0 and len(historical) > 0:
    # Header widths match the row format below; "Entry" is the analogue's
    # zone-entry year (cross_9k).
    print(f"\n{'Current':6s} → {'Analogue':8s} {'Entry':6s}  {'GDP dist':>12s} {'Z₁ dist':>8s}  Outcome")
    print("-" * 70)

    analogue_data = []
    for _, curr in zone_now.iterrows():
        if pd.isna(curr['Z_1']) or pd.isna(curr['gdp_pc_ppp']):
            continue

        # Exclude the country's own historical record: a current-cohort
        # country that also appears among historical entrants could
        # otherwise be matched to itself.
        candidates = historical[historical['iso3'] != curr['iso3']]
        if candidates.empty:
            continue

        # Euclidean distance in standardized space, kept local rather than
        # written back onto `historical` on every iteration.
        dist = np.sqrt(
            ((candidates['gdp_pc_ppp_entry'] - curr['gdp_pc_ppp']) / gdp_std) ** 2 +
            ((candidates['Z_1_entry'] - curr['Z_1']) / z1_std) ** 2
        )
        best = int(np.argmin(dist.to_numpy()))  # first minimum on ties
        match = candidates.iloc[best]
        match_dist = dist.iloc[best]

        # NOTE(review): anything other than exited_above == 1 is reported as
        # 'did not', which presumably includes entrants still in the zone
        # (censored spells) — confirm how cross_df codes them.
        outcome = 'CROSSED' if match['exited_above'] == 1 else 'did not'
        gdp_d = f"${abs(match['gdp_pc_ppp_entry'] - curr['gdp_pc_ppp']):,.0f}"
        z1_d = f"{abs(match['Z_1_entry'] - curr['Z_1']):.3f}"
        print(f"{curr['iso3']:6s} → {match['iso3']:8s} ({int(match['cross_9k'])})  "
              f"{gdp_d:>12s} {z1_d:>8s}  {outcome}")

        analogue_data.append({
            'current_iso3': curr['iso3'],
            'analogue_iso3': match['iso3'],
            'analogue_entry_year': match['cross_9k'],
            'analogue_outcome': outcome,
            'distance': match_dist,
        })

    pd.DataFrame(analogue_data).to_csv(
        '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table14_analogues.csv', index=False)

# ═══════════════════════════════════════════════════════════════════════
# 5g. Save cohort data
# ═══════════════════════════════════════════════════════════════════════
output_cols = ['iso3', 'gdp_pc_ppp', 'Z_1', 'old_dep', 'youth_dep', 'working_age_share',
               'kaopen', 'ca_gdp', 'gross_savings_gdp', 'fiscal_bal_gdp', 'nfa_gdp',
               'resource_rents_gdp', 'rgdp_growth', 'trade_openness',
               'p_cross', 'years_remaining', 'window_closes', 'risk', 'commodity_double_cliff']

# Add projection columns
for yr in projection_years:
    for var in ['Z_1', 'old_dep']:
        col = f'{var}_{yr}'
        if col in zone_now.columns:
            output_cols.append(col)

# Some columns (e.g. p_cross) exist only if earlier steps produced them.
available_cols = [c for c in output_cols if c in zone_now.columns]
zone_now[available_cols].to_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table12_current_cohort.csv', index=False)

# Summary statistics
print("\n\nSUMMARY:")
print(f"  Countries in zone: {len(zone_now)}")
for label in ('GREEN', 'AMBER', 'RED', 'Insufficient data'):
    print(f"  {label}: {(zone_now['risk'] == label).sum()}")
print(f"  Commodity double cliff: {zone_now['commodity_double_cliff'].sum()}")

# p_cross / years_remaining may be absent (e.g. the logit was not fitted) or
# entirely NaN; DataFrame.get returns None for a missing column, and medians
# are reported only when there is something to summarise.
p_cross_col = zone_now.get('p_cross')
if p_cross_col is not None and p_cross_col.notna().any():
    print(f"  Median P(cross): {p_cross_col.median():.2f}")
window_col = zone_now.get('years_remaining')
if window_col is not None and window_col.notna().any():
    print(f"  Median window: {window_col.median():.0f} years")

print("\n✓ Phase 5 complete.")
